Python Data Science Handbook

Chapter 4 - Visualization with Matplotlib

In [3]:
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

# Change mathplotlib style
plt.style.use('seaborn-whitegrid') 

# Static images of the plots are embedded in the notebook
%matplotlib inline 

Basic Plot Example

In [2]:
# 100 evenly spaced sample points on [0, 10]
x = np.linspace(0, 10, 100)

# Keep a handle on the figure: a later cell saves it to disk
fig = plt.figure()
# Sine as a solid line, cosine as a dashed line, both on the same axes
for curve, fmt in ((np.sin, '-'), (np.cos, '--')):
    plt.plot(x, curve(x), fmt)

Saving Figures to File

In [3]:
from IPython.display import Image

# Write the figure to a PNG file in the working directory
figure_path = 'my_figure.png'
fig.savefig(figure_path)

# Read the saved file back and show it as the cell output
Image(figure_path)
Out[3]:
In [4]:
# Ask the figure's canvas which output file types it supports.
# The result is echoed as the cell output.
fig.canvas.get_supported_filetypes()
Out[4]:
{'ps': 'Postscript',
 'eps': 'Encapsulated Postscript',
 'pdf': 'Portable Document Format',
 'pgf': 'PGF code for LaTeX',
 'png': 'Portable Network Graphics',
 'raw': 'Raw RGBA bitmap',
 'rgba': 'Raw RGBA bitmap',
 'svg': 'Scalable Vector Graphics',
 'svgz': 'Scalable Vector Graphics',
 'jpg': 'Joint Photographic Experts Group',
 'jpeg': 'Joint Photographic Experts Group',
 'tif': 'Tagged Image File Format',
 'tiff': 'Tagged Image File Format'}

Two Interfaces for the Price of One

In [5]:
# MATLAB-style interface.
# This interface is stateful: pyplot keeps track of the "current" figure
# and axes, and every plt.* call acts on them.
#   plt.gcf(): get current figure
#   plt.gca(): get current axes

plt.figure()  # create a plot figure

# Two stacked panels; plt.subplot(rows, columns, panel number) creates the
# panel and makes it the current axes, so the plot call lands in it.
for panel, curve in enumerate((np.sin, np.cos), start=1):
    plt.subplot(2, 1, panel)
    plt.plot(x, curve(x))
In [6]:
# Object-oriented interface: plot by calling methods on explicit Axes objects

# Create a column of two plots; `ax` is an array of two Axes objects
fig, ax = plt.subplots(nrows=2)
top_ax, bottom_ax = ax

# Each curve is drawn by the plot() method of its own Axes
top_ax.plot(x, np.sin(x))
bottom_ax.plot(x, np.cos(x));

Simple Line Plots

In [7]:
# Sample the curve densely: 1000 points on [0, 10]
x = np.linspace(0, 10, num=1000)

# An empty figure with a single axes, then the sine curve drawn on it
fig = plt.figure()
ax = plt.axes()
ax.plot(x, np.sin(x));
In [8]:
# Over-plotting multiple lines: successive plt.plot() calls in the same
# cell draw onto the same axes
for curve in (np.sin, np.cos):
    plt.plot(x, curve(x))

Adjusting the Plot: Line Colors and Styles

In [9]:
# Ways of specifying a line color. Each curve is shifted by one unit so the
# lines do not overlap.
plt.plot(x, np.sin(x - 0), color='blue')          # specify color by name
plt.plot(x, np.sin(x - 1), color='g')             # short color code (rgbcmyk)
plt.plot(x, np.sin(x - 2), color='0.75')          # Grayscale between 0 and 1
plt.plot(x, np.sin(x - 3), color='#FFDD44')       # Hex code (RRGGBB from 00 to FF)
plt.plot(x, np.sin(x - 4), color=(1.0, 0.2, 0.3)) # RGB tuple, values 0 to 1
# This call used to be swallowed by the comment on the previous line
plt.plot(x, np.sin(x - 5), color='chartreuse');   # all HTML color names supported
In [10]:
# Line styles: the four full names, followed by their short codes.
# Each line is offset by one unit so that all eight stay visible.
line_styles = [
    'solid', 'dashed', 'dashdot', 'dotted',  # full names
    '-', '--', '-.', ':',                    # short codes, same order
]
for offset, line_style in enumerate(line_styles):
    plt.plot(x, x + offset, linestyle=line_style)

Adjusting the Plot: Axes Limits

In [11]:
# Example of setting axis limits explicitly
x_limits = (-1, 11)
y_limits = (-1.5, 1.5)

plt.plot(x, np.sin(x))
plt.xlim(*x_limits)
plt.ylim(*y_limits);
In [12]:
# Example of reversing an axis: pass the limits in descending order.
# Both the x and the y limits are reversed here.
reversed_x = (10, 0)
reversed_y = (1.2, -1.2)

plt.plot(x, np.sin(x))
plt.xlim(*reversed_x)
plt.ylim(*reversed_y);
In [13]:
# Set the limits of both axes in a single call: [xmin, xmax, ymin, ymax]
xmin, xmax, ymin, ymax = -1, 11, -1.5, 1.5
plt.plot(x, np.sin(x))
plt.axis([xmin, xmax, ymin, ymax]);
In [14]:
# Automatically tighten the axis limits around the plotted data.
# There is no trailing semicolon, so the value returned by plt.axis()
# is echoed as the cell output.
plt.plot(x, np.sin(x))
plt.axis('tight')
Out[14]:
(-0.5, 10.5, -1.0999971452300779, 1.099999549246729)
In [15]:
# Example of an "equal" layout: plt.axis('equal') requests an equal aspect
# ratio, so one unit along x is drawn as long as one unit along y.
plt.plot(x, np.sin(x))
plt.axis('equal');

Labeling Plots

In [16]:
# Examples of axis labels and a title
plt.plot(x, np.sin(x))

# Label both axes first, then title the plot
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.title("A Sine Curve");
In [17]:
# Plot legend example: each line carries a label that plt.legend() picks up
labelled_curves = (
    (np.sin, '-g', 'sin(x)'),  # solid green
    (np.cos, ':b', 'cos(x)'),  # dotted blue
)
for curve, fmt, name in labelled_curves:
    plt.plot(x, curve(x), fmt, label=name)
plt.axis('equal')

plt.legend();
In [18]:
# Mapping between the MATLAB-style functions and the object-oriented methods:
#   plt.xlabel() → ax.set_xlabel()
#   plt.ylabel() → ax.set_ylabel()
#   plt.xlim()   → ax.set_xlim()
#   plt.ylim()   → ax.set_ylim()
#   plt.title()  → ax.set_title()

# Object-oriented interface to plotting: ax.set() applies all of these
# properties in a single call.
axes_settings = dict(
    xlim=(0, 10), ylim=(-2, 2),
    xlabel='x', ylabel='sin(x)',
    title='A Simple Plot',
)
ax = plt.axes()
ax.plot(x, np.sin(x))
ax.set(**axes_settings);

Simple Scatter Plots

Scatter Plots with plt.plot

In [19]:
# Scatter plot example: plt.plot() with a marker-only format string
n_points = 30
x = np.linspace(0, 10, num=n_points)
y = np.sin(x)

# 'o' draws a circle at every sample and no connecting line
plt.plot(x, y, 'o', color='black');
In [20]:
# Demonstration of point markers: one small random cloud per marker code

rng = np.random.RandomState(0)  # fixed seed, so the figure is reproducible
for marker in ['o', '.', ',', 'x', '+', 'v', '^', '<', '>', 's', 'd']:
    plt.plot(rng.rand(5), rng.rand(5), marker,
        label="marker='{0}'".format(marker))

# Build the legend and set the limits once, after all markers are drawn,
# instead of on every loop iteration. The data lies in [0, 1); the wider
# x-range leaves room for the legend.
plt.legend(numpoints=1)
plt.xlim(0, 1.8);
In [21]:
# Combining line and point markers in one format string

plt.plot(x, y, '-ok'); # line (-), circle marker (o), black (k)
In [22]:
# Customizing the line and the point markers through keyword arguments
marker_style = dict(
    color='gray',
    markersize=15, linewidth=4,
    markerfacecolor='white',
    markeredgecolor='gray',
    markeredgewidth=2,
)

# '-p': solid line with pentagon markers
plt.plot(x, y, '-p', **marker_style)
plt.ylim(-1.2, 1.2);

Scatter Plots with plt.scatter

A second, more powerful method of creating scatter plots is the plt.scatter()

In [23]:
# The same x/y points as above, drawn with plt.scatter() and circle markers
plt.scatter(x, y, marker='o');
In [24]:
# Changing size, color, and transparency of scatter points

n_points = 100
rng = np.random.RandomState(0)  # fixed seed, so the figure is reproducible

# The draws happen in this order: positions, then colors, then sizes
x = rng.randn(n_points)
y = rng.randn(n_points)
colors = rng.rand(n_points)
sizes = 1000 * rng.rand(n_points)

# c is mapped through the colormap; s sets the marker sizes
plt.scatter(x, y, c=colors, s=sizes, alpha=0.3, cmap='viridis')
plt.colorbar(); # show color scale
In [25]:
# Using point properties to encode features of the Iris data

from sklearn.datasets import load_iris

iris = load_iris()
# Transpose so that each row of `features` holds one measured feature
features = iris.data.T
x_feature, y_feature, size_feature = features[0], features[1], features[3]

# Position shows features 0 and 1, marker size shows feature 3,
# and color shows the target class
plt.scatter(x_feature, y_feature, alpha=0.2,
    s=100 * size_feature, c=iris.target, cmap='viridis')
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1]);

# plt.plot should be preferred over plt.scatter for large datasets:
# plt.scatter determines the appearance of each point individually.

Visualizing Errors

Basic Errorbars

In [26]:
# An errorbar example: a noisy sine curve with a constant error of dy

n_points = 50
dy = 0.8
x = np.linspace(0, 10, n_points)
y = np.sin(x) + dy * np.random.randn(n_points)

# fmt='.k': black dots, no connecting line
plt.errorbar(x, y, yerr=dy, fmt='.k');
In [27]:
# Customizing the look of the errorbars.
# Especially in crowded plots, it often helps to make the errorbars
# lighter than the points themselves.
bar_style = dict(
    fmt='o', color='black',
    ecolor='lightgray', elinewidth=3, capsize=0,
)
plt.errorbar(x, y, yerr=dy, **bar_style);

Continuous Errors

In [28]:
# The old sklearn GaussianProcess class can no longer be imported;
# GaussianProcessRegressor is its replacement.
from sklearn.gaussian_process import GaussianProcessRegressor

# define the model and draw some data
model = lambda x: x * np.sin(x)
xdata = np.array([1, 3, 5, 6, 8])
ydata = model(xdata)

# Compute the Gaussian process fit (default kernel).
# The estimator expects 2-D inputs of shape (n_samples, n_features),
# hence the np.newaxis.
gp = GaussianProcessRegressor()
gp.fit(xdata[:, np.newaxis], ydata)
xfit = np.linspace(0, 10, 1000)
# return_std=True also returns the predictive standard deviation
yfit, yfit_std = gp.predict(xfit[:, np.newaxis], return_std=True)
dyfit = 2 * yfit_std # 2*sigma ~ 95% confidence region

# Visualize the result: the training points in red...
plt.plot(xdata, ydata, 'or')
# ...and the continuous uncertainty as a filled region around the fit
plt.plot(xfit, yfit, '-', color='gray')
plt.fill_between(xfit, yfit - dyfit, yfit + dyfit,
                color='gray', alpha=0.2)
plt.xlim(0, 10);

Density and Contour Plots

In [29]:
# Sometimes it is useful to display three-dimensional data in
# two dimensions using contours or color-coded regions
# plt.contour for contour plots
# plt.contourf for filled contour plots
# plt.imshow for showing images

Visualizing a Three-Dimensional Function

In [30]:
# Contour plot using a function z = f(x,y)
def f(x, y):
    """Evaluate z = sin(x)**10 + cos(10 + y*x) * cos(x).

    Only NumPy ufuncs are used, so `x` and `y` may be scalars or arrays
    that broadcast against each other.
    """
    peak_term = np.sin(x) ** 10
    ripple_term = np.cos(10 + y * x) * np.cos(x)
    return peak_term + ripple_term

# x and y are positions on the plot; the z values are shown as contour
# levels. np.meshgrid is the most straightforward way to prepare such
# data: it builds two-dimensional grids from one-dimensional arrays.
x = np.linspace(0, 5, num=50)
y = np.linspace(0, 5, num=40)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)

# With a single color, negative values are drawn as dashed lines and
# positive values as solid lines by default.
contour = plt.contour(X, Y, Z, colors='black')
# No trailing semicolon: the labels returned by clabel are the cell output
plt.clabel(contour, inline=True, fontsize=8)
Out[30]:
<a list of 35 text.Text objects>
In [31]:
# Visualizing three-dimensional data with colored contours

# Color-code the lines by specifying a colormap with the cmap argument,
# and request more lines: 20 equally spaced intervals within the data range.
# (The level count now matches this description and the filled-contour
# cell that follows.)
plt.contour(X, Y, Z, 20, cmap='RdGy');
In [32]:
# Visualizing three-dimensional data with filled contours
plt.contourf(X, Y, Z, 20, cmap='RdGy')
plt.colorbar(); # Creates an index

# The colorbar makes it clear that the black regions are "peaks,"
# while the red regions are "valleys."
In [33]:
# Representing three-dimensional data as an image

# One could set the number of contours to a very high number, but this
# results in a rather inefficient plot: Matplotlib must render a new
# polygon for each step in the level.
# A better way to handle this is to use the plt.imshow() function,
# which interprets a two-dimensional grid of data as an image.

plt.imshow(Z, extent=[0, 5, 0, 5], origin='lower', cmap='RdGy')
plt.colorbar()
# plt.axis() takes the option as a positional string; the earlier
# aspect='image' keyword form is not a valid way to call it.
plt.axis('image');

# plt.imshow() doesn't accept an x and y grid, so you must manually
# specify the extent [xmin, xmax, ymin, ymax] of the image on the plot.
# plt.imshow() by default follows the standard image array definition
# where the origin is in the upper left, not in the lower left as in
# most contour plots. This must be changed when showing gridded data.
# plt.imshow() will automatically adjust the axis aspect ratio to match
# the input data; you can change this by setting, for example,
# plt.axis('image') to make x and y units match.
In [34]:
# Labeled contours on top of an image

# Contours carry their labels inline; the image behind them is partially
# transparent. The image is drawn last, so the colorbar refers to it.
image_options = dict(extent=[0, 5, 0, 5], origin='lower',
                     cmap='RdGy', alpha=0.5)

contours = plt.contour(X, Y, Z, 3, colors='black')
plt.clabel(contours, inline=True, fontsize=8)
plt.imshow(Z, **image_options)
plt.colorbar();

Histograms, Binnings, and Density

In [35]:
# A simple histogram of 1000 standard-normal samples
# (unseeded, so the exact bars differ from run to run)

data = np.random.randn(1000)
plt.hist(data);
In [36]:
# A more customized histogram: 30 bins, normalized to a density
# (density=True), drawn as a semi-transparent filled step with no edge
hist_style = dict(
    bins=30, density=True, alpha=0.5,
    histtype='stepfilled', color='steelblue',
    edgecolor='none',
)
plt.hist(data, **hist_style);
In [37]:
# Over-plotting multiple histograms.
# histtype='stepfilled' together with some transparency (alpha) is
# useful when comparing histograms of several distributions.

x1 = np.random.normal(0, 0.8, 1000)
x2 = np.random.normal(-2, 1, 1000)
x3 = np.random.normal(3, 2, 1000)

# 'density' replaces the deprecated 'normed' keyword
kwargs = dict(histtype='stepfilled', alpha=0.3, density=True, bins=40)

plt.hist(x1, **kwargs)
plt.hist(x2, **kwargs)
plt.hist(x3, **kwargs);
In [38]:
# Count the number of points in each bin without plotting:
# np.histogram returns the per-bin counts and the bin edges
counts, bin_edges = np.histogram(data, bins=5) 
print(counts)
[  6 130 479 346  39]

Two-Dimensional Histograms and Binnings

In [39]:
# We can create histograms in two dimensions by dividing points among
# two-dimensional bins. Start with 10000 draws from a 2-D Gaussian.
mean = [0, 0]
cov = [[1, 1], [1, 2]]
samples = np.random.multivariate_normal(mean, cov, 10000)
# One row per sample; transpose to unpack the x and y coordinates
x, y = samples.T

plt.hist2d: Two-dimensional histogram

In [40]:
# A two-dimensional histogram with plt.hist2d (rectangular bins)

plt.hist2d(x, y, bins=30, cmap='Blues')
# Label the colorbar as it is created
cb = plt.colorbar(label='counts in bin')

plt.hexbin: Hexagonal binnings

In [41]:
# A two-dimensional histogram with plt.hexbin.
# Another natural shape for such a tessellation is the regular hexagon.
plt.hexbin(x, y, gridsize=30, cmap='Blues')
cb = plt.colorbar()
cb.set_label('count in bin')

Kernel density estimation

In [42]:
# A kernel density representation of a distribution.
# Kernel density estimation (KDE) is another common method of evaluating
# densities in multiple dimensions.

from scipy.stats import gaussian_kde

# Fit on an array of size [Ndim, Nsamples]
data = np.vstack([x, y])
kde = gaussian_kde(data)

# Evaluate on a regular 40 x 40 grid
xgrid = np.linspace(-3.5, 3.5, 40)
ygrid = np.linspace(-6, 6, 40)
Xgrid, Ygrid = np.meshgrid(xgrid, ygrid)
grid_points = np.vstack([Xgrid.ravel(), Ygrid.ravel()])
Z = kde.evaluate(grid_points)

# Plot the result as an image; Z is flat, so restore the grid shape first
density_image = Z.reshape(Xgrid.shape)
plt.imshow(density_image,
   origin='lower', aspect='auto',
   extent=[-3.5, 3.5, -6, 6],
   cmap='Blues')
cb = plt.colorbar()
cb.set_label("density")

Customizing Plot Legends

In [43]:
# A default plot legend

x = np.linspace(0, 10, num=1000)

fig, ax = plt.subplots()
# Every labelled line becomes an entry in the legend
for curve, fmt, name in ((np.sin, '-b', 'Sine'), (np.cos, '--r', 'Cosine')):
    ax.plot(x, curve(x), fmt, label=name)
ax.axis('equal')
leg = ax.legend();
In [44]:
# A customized plot legend: moved to the upper left, frame turned off
legend_options = dict(loc='upper left', frameon=False)
ax.legend(**legend_options)
# Redisplay the existing figure with the new legend
fig
Out[44]:
In [45]:
# A two-column plot legend: ncol sets the number of columns
legend_options = dict(frameon=False, loc='lower center', ncol=2)
ax.legend(**legend_options)
fig
Out[45]:
In [46]:
# Use a rounded box (fancybox), add a shadow, change the transparency
# (alpha value) of the frame, or change the padding around the text
legend_options = dict(fancybox=True, framealpha=1, shadow=True, borderpad=1)
ax.legend(**legend_options)
fig
Out[46]:

Choosing Elements for the Legend

In [47]:
# Customization of legend elements.
# The objects returned by the plot commands let us fine-tune which
# elements and labels appear in the legend.

# Four phase shifts (0, pi/2, pi, 3pi/2) give one column of y per curve
phase_shifts = np.pi * np.arange(0, 2, 0.5)
y = np.sin(x[:, np.newaxis] + phase_shifts)

# lines is a list of plt.Line2D instances, one per column of y
lines = plt.plot(x, y)
first_two_lines = lines[:2]
plt.legend(first_two_lines, ['first', 'second']);
In [48]:
# Alternative method: label only the lines that should be listed.
# The legend ignores all elements without a label attribute set.

for column, name in enumerate(['first', 'second']):
    plt.plot(x, y[:, column], label=name)
# The remaining columns are plotted without labels
plt.plot(x, y[:, 2:])
plt.legend(framealpha=1, frameon=True);

Legend for Size of Points

In [49]:
# Location, geographic size, and population of California cities

cities = pd.read_csv('data/california_cities.csv')

# Extract the data we're interested in
lat, lon = cities['latd'], cities['longd']
population, area = cities['population_total'], cities['area_total_km2']

# Scatter the points, using size and color but no label
plt.scatter(lon, lat, label=None,
    c=np.log10(population), cmap='viridis',
    s=area, linewidth=0, alpha=0.5)
# plt.axis() takes the option as a positional string
plt.axis('equal')
plt.xlabel('longitude')
plt.ylabel('latitude')
plt.colorbar(label='log$_{10}$(population)')
plt.clim(3, 7)

# Here we create a legend: plot empty lists with the desired size and
# label, one per reference area. The loop header had been swallowed by
# the comment, so the per-city `area` column was used as the size of a
# single empty scatter and the legend did not work.
# A separate loop variable keeps the `area` column intact.
for legend_area in [100, 300, 500]:
    plt.scatter([], [], c='k', alpha=0.3,
        s=legend_area, label=str(legend_area) + ' km$^2$')
plt.legend(scatterpoints=1, frameon=False,
    labelspacing=1, title='City Area')
plt.title('California Cities: Area and Population');

Multiple Legends

If you try to create a second legend using plt.legend() or ax.legend(), it will simply override the first one. We can work around this by creating a new legend artist from scratch, and then using the lower-level ax.add_artist() method to manually add the second artist to the plot.

In [50]:
from matplotlib.legend import Legend

fig, ax = plt.subplots()
styles = ['-', '--', '-.', ':']
x = np.linspace(0, 10, 1000)

# Four black sine curves, each shifted by a further quarter period and
# drawn with its own line style
lines = []
for i, line_style in enumerate(styles):
    lines.extend(ax.plot(x, np.sin(x - i * np.pi / 2),
        line_style, color='black'))
ax.axis('equal')

# First legend (lines A and B), created the usual way
first_pair, second_pair = lines[:2], lines[2:]
ax.legend(first_pair, ['line A', 'line B'],
    loc='upper right', frameon=False)

# Second legend: build the Legend artist directly and add it manually,
# since another ax.legend() call would replace the first legend
leg = Legend(ax, second_pair, ['line C', 'line D'],
    loc='lower right', frameon=False)
ax.add_artist(leg);

Customizing Colorbars

In [51]:
# The simplest colorbar can be created with the plt.colorbar function
x = np.linspace(0, 10, 1000)

# Outer product of a cosine column and a sine row -> 1000 x 1000 image
sine_row = np.sin(x)
cosine_column = np.cos(x[:, np.newaxis])
I = sine_row * cosine_column

plt.imshow(I)
plt.colorbar();
In [52]:
# The cmap argument selects the colormap; here the same image in grayscale
plt.imshow(I, cmap='gray');

All the available colormaps are in the plt.cm namespace; in IPython, tab completion on plt.cm.<TAB> lists them.

  • Sequential colormaps: These consist of one continuous sequence of colors (e.g., binary or viridis);

  • Divergent colormaps: These usually contain two distinct colors, which show positive and negative deviations from a mean (e.g., RdBu or PuOr);

  • Qualitative colormaps: These mix colors with no particular sequence (e.g., rainbow or jet);

In [53]:
from matplotlib.colors import LinearSegmentedColormap

def grayscale_cmap(cmap):
    """Return a grayscale version of the given colormap.

    Parameters
    ----------
    cmap : str or Colormap
        Colormap name, or a Colormap instance.

    Returns
    -------
    LinearSegmentedColormap
        Colormap named ``<original name>_gray`` with ``cmap.N`` entries,
        in which every color is replaced by a gray of the same perceived
        luminance. The alpha channel is left untouched.
    """
    # plt.get_cmap replaces plt.cm.get_cmap, which is deprecated and
    # removed in newer Matplotlib releases
    cmap = plt.get_cmap(cmap)
    colors = cmap(np.arange(cmap.N))

    # convert RGBA to perceived grayscale luminance
    # cf. http://alienryderflex.com/hsp.html
    RGB_weight = [0.299, 0.587, 0.114]
    luminance = np.sqrt(np.dot(colors[:, :3] ** 2, RGB_weight))
    # Write the same luminance into the R, G and B channels
    colors[:, :3] = luminance[:, np.newaxis]

    return LinearSegmentedColormap.from_list(cmap.name + "_gray", colors, cmap.N)

def view_colormap(cmap):
    """Plot a colormap with its grayscale equivalent.

    Parameters
    ----------
    cmap : str or Colormap
        Colormap name, or a Colormap instance.

    Creates a new figure with two tick-less strips: the colormap itself
    on top and its grayscale version (from grayscale_cmap) below.
    """
    # plt.get_cmap replaces plt.cm.get_cmap, which is deprecated and
    # removed in newer Matplotlib releases
    cmap = plt.get_cmap(cmap)
    colors = cmap(np.arange(cmap.N))

    cmap = grayscale_cmap(cmap)
    grayscale = cmap(np.arange(cmap.N))

    fig, ax = plt.subplots(2, figsize=(6, 2),
        subplot_kw=dict(xticks=[], yticks=[]))
    # Wrapping the color list in another list makes it a one-row image
    ax[0].imshow([colors], extent=[0, 10, 0, 1])
    ax[1].imshow([grayscale], extent=[0, 10, 0, 1])

# The jet colormap and its uneven luminance scale:
# compare the color strip with the grayscale strip beneath it
view_colormap('jet')
In [54]:
# The cubehelix colormap and its luminance (grayscale strip beneath)
view_colormap('cubehelix')
In [55]:
# The RdBu colormap and its grayscale equivalent
view_colormap('RdBu')
In [56]:
# Specifying colormap extensions.
# We can narrow the color limits and indicate the out-of-bounds values
# with a triangular arrow at the top and bottom of the colorbar by
# setting the extend property.

# Make noise in 1% of the image pixels (this modifies I in place)
speckles = (np.random.random(I.shape) < 0.01)
I[speckles] = np.random.normal(0, 3, np.count_nonzero(speckles))

plt.figure(figsize=(10, 3.5))

# Left panel: default color limits. They respond to the noisy pixels,
# and the range of the noise washes out the pattern we are interested in.
# Right panel: limits narrowed to [-1, 1], with extension arrows on the
# colorbar.
panels = (
    (1, {}, None),
    (2, dict(extend='both'), (-1, 1)),
)
for panel, colorbar_options, color_limits in panels:
    plt.subplot(1, 2, panel) # row, col, id
    plt.imshow(I, cmap='RdBu')
    plt.colorbar(**colorbar_options)
    if color_limits is not None:
        plt.clim(*color_limits)

Discrete colorbars

Colormaps are by default continuous, but sometimes you’d like to represent discrete values. The easiest way to do this is to use the plt.get_cmap() function (plt.cm.get_cmap() in older Matplotlib releases), and pass the name of a suitable colormap along with the number of desired bins.

In [57]:
# Discrete colorbar: the second argument asks for the 'Blues' colormap
# resampled to 6 bins. plt.get_cmap replaces plt.cm.get_cmap, which is
# deprecated and removed in newer Matplotlib releases.
plt.imshow(I, cmap=plt.get_cmap('Blues', 6))
plt.colorbar()
plt.clim(-1, 1);

Example: Handwritten Digits

In [58]:
# Sample of handwritten digit data:
# load images of the digits 0 through 5 and visualize several of them

from sklearn.datasets import load_digits

digits = load_digits(n_class=6)
fig, ax = plt.subplots(8, 8, figsize=(6, 6))  # 8x8 grid

# One image per axes, in dataset order, with the ticks removed
for axi, image in zip(ax.flat, digits.images):
    axi.imshow(image, cmap='binary')
    axi.set(xticks=[], yticks=[])
In [59]:
# Manifold embedding of handwritten digit pixels

# Project the digits into 2 dimensions using Isomap
from sklearn.manifold import Isomap

iso = Isomap(n_components=2)
projection = iso.fit_transform(digits.data)

# Use a discrete colormap to view the results, setting the ticks and clim
# to improve the aesthetics of the resulting colorbar.
# plt.get_cmap replaces plt.cm.get_cmap, which is deprecated and removed
# in newer Matplotlib releases.

# plot the results
plt.scatter(projection[:, 0], projection[:, 1], lw=0.1,
    c=digits.target, cmap=plt.get_cmap('cubehelix', 6))
plt.colorbar(ticks=range(6), label='digit value')
# Limits half a unit beyond 0 and 5 centre each tick on its color bin
plt.clim(-0.5, 5.5)

Multiple Subplots

Four routines for creating subplots in Matplotlib

In [60]:
# Example of an inset axes.
# The rectangle is [left, bottom, width, height] in figure coordinates.
inset_rect = [0.65, 0.65, 0.2, 0.2]
ax1 = plt.axes() # standard axes
ax2 = plt.axes(inset_rect)
In [61]:
# Vertically stacked axes example

fig = plt.figure()
shared_ylim = (-1.2, 1.2)

# Upper panel: the top of the lower panel (0.1 + 0.4) is the bottom of
# this one (0.5). Its x tick labels are hidden.
ax1 = fig.add_axes([0.1, 0.5, 0.8, 0.4],
    xticklabels=[], ylim=shared_ylim)
# Lower panel
ax2 = fig.add_axes([0.1, 0.1, 0.8, 0.4],
    ylim=shared_ylim)

x = np.linspace(0, 10)
ax1.plot(np.sin(x))
ax2.plot(np.cos(x));

plt.subplot: Simple Grids of Subplots

In [62]:
# A plt.subplot() example: a 2 x 3 grid, each panel showing its own
# (rows, cols, id) triple. Panel ids start at 1.
n_rows, n_cols = 2, 3
for panel in range(1, n_rows * n_cols + 1):
    plt.subplot(n_rows, n_cols, panel)
    plt.text(0.5, 0.5, str((n_rows, n_cols, panel)),
        fontsize=18, ha='center')
In [63]:
# Adjust the spacing between the subplots (hspace: height, wspace: width)
n_rows, n_cols = 2, 3
fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for panel in range(1, n_rows * n_cols + 1):
    ax = fig.add_subplot(n_rows, n_cols, panel)
    ax.text(0.5, 0.5, str((n_rows, n_cols, panel)),
        fontsize=18, ha='center')

plt.subplots: The Whole Grid in One Go

In [64]:
# Shared x and y axes in plt.subplots()

# The optional keywords sharex and sharey specify the relationships
# between different axes: here the x-axis is shared within each column
# and the y-axis within each row.
fig, ax = plt.subplots(2, 3, sharex='col', sharey='row')
In [65]:
# The axes are in a two-dimensional array, indexed by [row, col].
# In comparison to plt.subplot(), plt.subplots() is more consistent with
# Python's conventional 0-based indexing.
for (row, col), axi in np.ndenumerate(ax):
    axi.text(0.5, 0.5, str((row, col)),
        fontsize=18, ha='center')
fig
Out[65]:

plt.GridSpec: More Complicated Arrangements

To go beyond a regular grid to subplots that span multiple rows and columns, plt.GridSpec() is the best tool. It is simply a convenient interface that is recognized by the plt.subplot() command.

In [66]:
# Irregular subplots with plt.GridSpec
grid = plt.GridSpec(2, 3, wspace=0.4, hspace=0.3)

# Slicing the grid makes a subplot span several cells
cells = (
    grid[0, 0],   # top-left cell
    grid[0, 1:],  # rest of the top row
    grid[1, :2],  # first two cells of the bottom row
    grid[1, 2],   # bottom-right cell
)
for cell in cells:
    plt.subplot(cell)
In [67]:
# Visualizing multidimensional distributions with plt.GridSpec
# The Seaborn package has its own plotting API for building a plot like this.

# Create some normally distributed data.
# `mean` is defined here: the assignment had been swallowed by the comment,
# so the cell only worked with a value left over from an earlier cell.
mean = [0, 0]
cov = [[1, 1], [1, 2]]
x, y = np.random.multivariate_normal(mean, cov, 3000).T

# Set up the axes with gridspec: the main panel fills the upper-right
# 3 x 3 cells, with marginal histograms in the left column and bottom row
fig = plt.figure(figsize=(6, 6))
grid = plt.GridSpec(4, 4, hspace=0.2, wspace=0.2)
main_ax = fig.add_subplot(grid[:-1, 1:])
y_hist = fig.add_subplot(grid[:-1, 0], xticklabels=[], sharey=main_ax)
x_hist = fig.add_subplot(grid[-1, 1:], yticklabels=[], sharex=main_ax)

# scatter points on the main axes
main_ax.plot(x, y, 'ok', markersize=3, alpha=0.2)

# histograms on the attached axes; each count axis is inverted so the
# bars grow away from the main panel
x_hist.hist(x, 40, histtype='stepfilled',
    orientation='vertical', color='gray')
x_hist.invert_yaxis()
y_hist.hist(y, 40, histtype='stepfilled',
    orientation='horizontal', color='gray')
y_hist.invert_xaxis()

Text and Annotation

Creating a good visualization involves guiding the reader so that the figure tells a story. In some cases, this story can be told in an entirely visual manner, without the need for added text, but in others, small textual cues and labels are necessary. Perhaps the most basic types of annotations you will use are axes labels and titles, but the options go beyond this.

Example: Effect of Holidays on US Births

In [68]:
# Average daily births by date

births = pd.read_csv('data/births.csv')

# Robust outlier cut: keep rows within 5 sigma of the median, where sigma
# is estimated from the interquartile range
quartiles = np.percentile(births['births'], [25, 50, 75])
mu, sig = quartiles[1], 0.74 * (quartiles[2] - quartiles[0])
# .copy() makes the filtered frame independent, so the column assignment
# below does not trigger SettingWithCopyWarning
births = births.query('(births > @mu - 5 * @sig) & (births < @mu + 5 * @sig)').copy()

births['day'] = births['day'].astype(int)

# Build a datetime index from year/month/day encoded as the integer YYYYMMDD
births.index = pd.to_datetime(10000 * births.year +
    100 * births.month +
    births.day, format='%Y%m%d')

# Mean number of births for each (month, day) pair
births_by_date = births.pivot_table('births',
    [births.index.month, births.index.day])

# Put every (month, day) into the single year 2012 so the result can be
# plotted on a date axis. pd.Timestamp replaces the pd.datetime alias,
# which is no longer available in current pandas.
births_by_date.index = [pd.Timestamp(year=2012, month=month, day=day)
    for (month, day) in births_by_date.index]

fig, ax = plt.subplots(figsize=(12, 4))
births_by_date.plot(ax=ax);
In [69]:
# Add labels to the plot: (date, height, text, alignment options)
style = dict(size=10, color='gray')
holiday_labels = [
    ('2012-1-1', 3950, "New Year's Day", {}),
    ('2012-7-4', 4250, "Independence Day", dict(ha='center')),
    ('2012-9-4', 4850, "Labor Day", dict(ha='center')),
    ('2012-10-31', 4600, "Halloween", dict(ha='right')),
    ('2012-11-25', 4450, "Thanksgiving", dict(ha='center')),
    ('2012-12-25', 3850, "Christmas ", dict(ha='right')),
]
for date, height, text, alignment in holiday_labels:
    ax.text(date, height, text, **alignment, **style)

# Label the axes
ax.set(title='USA births by day of year (1969-1988)',
      ylabel='average daily births')

# Format the x axis with centered month labels: major ticks at month
# starts carry no label, minor ticks on the 15th carry the month name
x_axis = ax.xaxis
x_axis.set_major_locator(mpl.dates.MonthLocator())
x_axis.set_minor_locator(mpl.dates.MonthLocator(bymonthday=15))
x_axis.set_major_formatter(plt.NullFormatter())
x_axis.set_minor_formatter(mpl.dates.DateFormatter('%h'));

fig
Out[69]:

Transforms and Text Position

Any graphics display framework needs some scheme for translating between coordinate systems. Matplotlib has a well-developed set of tools that it uses internally to perform these transforms (the tools can be explored in the matplotlib.transforms submodule). The average user rarely needs to worry about the details of these transforms.

There are three predefined transforms:

  • ax.transData: Transform associated with data coordinates;
  • ax.transAxes: Transform associated with the axes (in units of axes dimensions);
  • fig.transFigure: Transform associated with the figure (in units of figure dimensions);
In [70]:
# Comparing Matplotlib's coordinate systems

fig, ax = plt.subplots(facecolor='lightgray')
ax.axis([0, 10, 0, 10])

# The same kind of label placed in data, axes and figure coordinates.
# transform=ax.transData is the default, but it is given explicitly here.
labelled_points = [
    ((1, 5), ". Data: (1, 5)", ax.transData),
    ((0.5, 0.1), ". Axes: (0.5, 0.1)", ax.transAxes),
    ((0.2, 0.2), ". Figure: (0.2, 0.2)", fig.transFigure),
]
for (pos_x, pos_y), text, transform in labelled_points:
    ax.text(pos_x, pos_y, text, transform=transform)
In [71]:
# Comparing Matplotlib's coordinate systems, continued.

# If we change the axes limits, only the transData coordinates are
# affected, while the others remain stationary.
ax.set(xlim=(0, 2), ylim=(-6, 6))
fig
Out[71]:

Arrows and Annotation

In [72]:
# Annotate one maximum and one minimum of a cosine curve with arrows
fig, ax = plt.subplots()
x = np.linspace(0, 20, 1000)
ax.plot(x, np.cos(x))
ax.axis('equal')

# Arrow looks are controlled through the arrowprops dictionary
straight_arrow = dict(facecolor='black', shrink=0.05)
curved_arrow = dict(facecolor='black', width=1.5,
    connectionstyle="angle3,angleA=0,angleB=-90")

# xy is the point being annotated, xytext the position of the label
ax.annotate('local maximum', xy=(6.28, 1), xytext=(10, 4),
    arrowprops=straight_arrow)
ax.annotate('local minimum', xy=(5 * np.pi, -1), xytext=(2, -6),
    arrowprops=curved_arrow);
In [73]:
# Some of the arrowprops possibilities

# Annotated average birth rates by day: the births plot is redrawn on a new
# figure, and each holiday is marked with a different annotation style.

fig, ax = plt.subplots(figsize=(12, 4))
births_by_date.plot(ax=ax)
# Add labels to the plot
# Simple "->" arrow on a slightly curved connector; the text is offset
# from the annotated point (textcoords='offset points')
ax.annotate("New Year's Day", xy=('2012-1-1', 4100),  xycoords='data',
    xytext=(50, -30), textcoords='offset points',
    arrowprops=dict(arrowstyle="->",
        connectionstyle="arc3,rad=-0.2"))
# Text in a rounded, unfilled box with a plain "->" arrow
ax.annotate("Independence Day", xy=('2012-7-4', 4250),  xycoords='data',
    bbox=dict(boxstyle="round", fc="none", ec="gray"),
        xytext=(10, -40), textcoords='offset points', ha='center',
        arrowprops=dict(arrowstyle="->"))
# Label without an arrow...
ax.annotate('Labor Day', xy=('2012-9-4', 4850), xycoords='data', ha='center',
    xytext=(0, -20), textcoords='offset points')
# ...plus an empty-text annotation that draws a "|-|" bar between two dates
ax.annotate('', xy=('2012-9-1', 4850), xytext=('2012-9-7', 4850),
    xycoords='data', textcoords='data',
    arrowprops={'arrowstyle': '|-|,widthA=0.2,widthB=0.2', })
# "fancy" arrow: gray fill, no edge, on an angled connector
ax.annotate('Halloween', xy=('2012-10-31', 4600),  xycoords='data',
    xytext=(-80, -40), textcoords='offset points',
    arrowprops=dict(arrowstyle="fancy",
        fc="0.6", ec="none",
        connectionstyle="angle3,angleA=0,angleB=-90"))
# Boxed text with an arrow that bends through a rounded corner
ax.annotate('Thanksgiving', xy=('2012-11-25', 4500),  xycoords='data',
    xytext=(-120, -60), textcoords='offset points',
    bbox=dict(boxstyle="round4,pad=.5", fc="0.9"),
    arrowprops=dict(arrowstyle="->",
        connectionstyle="angle,angleA=0,angleB=80,rad=20"))
# Larger text in a faint box with a faint wedge-shaped arrow
ax.annotate('Christmas', xy=('2012-12-25', 3850),  xycoords='data',
     xytext=(-30, 0), textcoords='offset points',
     size=13, ha='right', va="center",
     bbox=dict(boxstyle="round", alpha=0.1),
     arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.1));

# Label the axes
ax.set(title='USA births by day of year (1969-1988)',
       ylabel='average daily births')

# Format the x axis with centered month labels: major ticks at month
# starts carry no label, minor ticks on the 15th carry the month name
ax.xaxis.set_major_locator(mpl.dates.MonthLocator())
ax.xaxis.set_minor_locator(mpl.dates.MonthLocator(bymonthday=15))
ax.xaxis.set_major_formatter(plt.NullFormatter())
ax.xaxis.set_minor_formatter(mpl.dates.DateFormatter('%h'));
ax.set_ylim(3600, 5400);

Customizing Ticks

Each axes has attributes xaxis and yaxis, which in turn have attributes that contain all the properties of the lines, ticks, and labels that make up the axes.

Major and Minor Ticks

In [74]:
# Example of logarithmic scales and labels
ax = plt.axes(xscale='log', yscale='log')

# Print the x-axis locators first, then its formatters.
# As the output shows, both major and minor ticks have their locations
# specified by a LogLocator (which makes sense for a logarithmic plot).
x_axis = ax.xaxis
ticker_getters = (
    x_axis.get_major_locator,
    x_axis.get_minor_locator,
    x_axis.get_major_formatter,
    x_axis.get_minor_formatter,
)
for get_ticker in ticker_getters:
    print(get_ticker())
<matplotlib.ticker.LogLocator object at 0x11c3a3910>
<matplotlib.ticker.LogLocator object at 0x11bd72710>
<matplotlib.ticker.LogFormatterSciNotation object at 0x11c3a3bd0>
<matplotlib.ticker.LogFormatterSciNotation object at 0x11c399150>

Hiding Ticks or Labels

Perhaps the most common tick/label formatting operation is the act of hiding ticks or labels. We can do this using plt.NullLocator() and plt.NullFormatter()

In [75]:
# Plot with hidden tick labels (x-axis) and hidden ticks (y-axis)
ax = plt.axes()
values = np.random.rand(50)
ax.plot(values)

# NullFormatter keeps the x ticks but drops their labels;
# NullLocator removes the y ticks altogether
ax.xaxis.set_major_formatter(plt.NullFormatter())
ax.yaxis.set_major_locator(plt.NullLocator())
In [76]:
# Hiding ticks within image plots, e.g. when displaying a grid of images

# Get some face data from scikit-learn
from sklearn.datasets import fetch_olivetti_faces

fig, ax = plt.subplots(5, 5, figsize=(5, 5))
fig.subplots_adjust(hspace=0, wspace=0)  # no gaps between the panels

faces = fetch_olivetti_faces().images

for (i, j), axi in np.ndenumerate(ax):
    # Remove the ticks on both axes of this panel
    for axis in (axi.xaxis, axi.yaxis):
        axis.set_major_locator(plt.NullLocator())
    # Row i shows images 10*i .. 10*i + 4
    axi.imshow(faces[10 * i + j], cmap="bone")

Reducing or Increasing the Number of Ticks

In [77]:
# A default plot with ticks: a 4 x 4 grid of axes that share their x and
# y axes, left with Matplotlib's default tick locators

fig, ax = plt.subplots(4, 4, sharex=True, sharey=True)
In [78]:
# Customizing the number of ticks

# plt.MaxNLocator(n): at most n ticks will be displayed.
# Every axes gets its own locator on both the x and the y axis.
for axi in ax.flat:
    for axis in (axi.xaxis, axi.yaxis):
        axis.set_major_locator(plt.MaxNLocator(3))
fig
Out[78]:

Fancy Tick Formats

Matplotlib’s default tick formatting can leave a lot to be desired; it works well as a broad default, but sometimes you’d like to do something more.

In [79]:
# A default plot with integer ticks

# Plot a sine and a cosine curve over [0, 3*pi]
fig, ax = plt.subplots()
x = np.linspace(0, 3 * np.pi, 1000)
for curve, name in ((np.sin, 'Sine'), (np.cos, 'Cosine')):
    ax.plot(x, curve(x), lw=3, label=name)

# Set up grid, legend, and limits
ax.grid(True)
ax.legend(frameon=False)
ax.axis('equal')
ax.set_xlim(0, 3 * np.pi);
In [80]:
# Ticks at multiples of pi/2

# MultipleLocator(base) places ticks at multiples of the number you give:
# major ticks every pi/2, minor ticks every pi/4
x_axis = ax.xaxis
x_axis.set_major_locator(plt.MultipleLocator(np.pi / 2))
x_axis.set_minor_locator(plt.MultipleLocator(np.pi / 4))
fig
Out[80]:
In [81]:
# Ticks with custom labels

def format_func(value, tick_number):
    """Label a tick position as a multiple of pi/2.

    Parameters
    ----------
    value : float
        Tick position in data units (radians).
    tick_number : int
        Tick index supplied by ``FuncFormatter``; not used.

    Returns
    -------
    str
        ``"0"`` for zero, otherwise a mathtext string such as
        ``$\\pi/2$``, ``$\\pi$``, ``$3\\pi/2$`` or ``$-\\pi$``.
    """
    # find number of multiples of pi/2
    N = int(np.round(2 * value / np.pi))
    if N == 0:
        return "0"

    # Handle the sign separately so that a coefficient of one is dropped
    # for negative ticks too (-pi/2 rather than -1pi/2).
    sign = "-" if N < 0 else ""
    magnitude = abs(N)

    if magnitude % 2:
        # odd multiple of pi/2: keep the "/2" denominator
        coeff = "" if magnitude == 1 else str(magnitude)
        return r"${0}{1}\pi/2$".format(sign, coeff)

    # even multiple of pi/2: reduce to a whole multiple of pi
    whole = magnitude // 2
    coeff = "" if whole == 1 else str(whole)
    return r"${0}{1}\pi$".format(sign, coeff)
    
ax.xaxis.set_major_formatter(plt.FuncFormatter(format_func))
fig
Out[81]:

Summary of Formatters and Locators

In [82]:
# Locator class      ---           Description
# ------------------------------------------------------------------------------------
# NullLocator        --- No ticks
# FixedLocator       --- Tick locations are fixed
# IndexLocator       --- Locator for index plots (e.g., where x = range(len(y)))
# LinearLocator      --- Evenly spaced ticks from min to max
# LogLocator         --- Logarithmically ticks from min to max
# MultipleLocator    --- Ticks and range are a multiple of base
# MaxNLocator        --- Finds up to a max number of ticks at nice locations
# AutoLocator        --- (Default) MaxNLocator with simple defaults
# AutoMinorLocator   --- Locator for minor ticks
# NullFormatter      --- No labels on the ticks
# IndexFormatter     --- Set the strings from a list of labels
# FixedFormatter     --- Set the strings manually for the labels
# FuncFormatter      --- User-defined function sets the labels
# FormatStrFormatter --- Use a format string for each value
# ScalarFormatter    --- (Default) Formatter for scalar values
# LogFormatter       --- Default formatter for log axes

Customizing Matplotlib: Configurations and Stylesheets

Plot Customization by Hand

In [83]:
# A histogram in Matplotlib’s default style
plt.style.use('classic')
%matplotlib inline

x = np.random.randn(1000)
plt.hist(x);
In [84]:
# A histogram with manual customizations

# gray panel background, with the grid drawn underneath the bars
ax = plt.axes(facecolor='#E6E6E6')
ax.set_axisbelow(True)

# solid white grid lines
plt.grid(color='w', linestyle='solid')

# remove the frame around the plot
for side in ax.spines.values():
    side.set_visible(False)

# keep ticks on the bottom and left edges only
ax.xaxis.tick_bottom()
ax.yaxis.tick_left()

# gray, outward-pointing ticks with gray labels on both axes
ax.tick_params(colors='gray', direction='out')
for label in ax.get_xticklabels() + ax.get_yticklabels():
    label.set_color('gray')

# bars: red fill, edges in the same gray as the background
ax.hist(x, edgecolor='#E6E6E6', color='#EE6666');

Changing the Defaults: rcParams

You can adjust this configuration at any time using the plt.rc convenience routine.

In [85]:
# We’ll start by saving a copy of the current rcParams dictionary,
# so we can easily reset these changes in the current session:
IPython_default = plt.rcParams.copy()
In [86]:
# Change settings
from matplotlib import cycler

colors = cycler('color',
    ['#EE6666', '#3388BB', '#9988DD',
    '#EECC55', '#88BB44', '#FFBBBB'])
plt.rc('axes', facecolor='#E6E6E6', edgecolor='none',
    axisbelow=True, grid=True, prop_cycle=colors)
plt.rc('grid', color='w', linestyle='solid')
plt.rc('xtick', direction='out', color='gray')
plt.rc('ytick', direction='out', color='gray')
plt.rc('patch', edgecolor='#E6E6E6')
plt.rc('lines', linewidth=2)
In [87]:
# A customized histogram using rc settings
plt.hist(x);
In [88]:
# A line plot with customized styles
for i in range(4):
    plt.plot(np.random.rand(10))

Stylesheets

Matplotlib's style module provides a number of built-in stylesheets and lets you create your own. These stylesheets are formatted similarly to the .matplotlibrc files mentioned earlier, but must be named with a .mplstyle extension.

In [89]:
plt.style.available[:5]
Out[89]:
['seaborn-dark',
 'seaborn-darkgrid',
 'seaborn-ticks',
 'fivethirtyeight',
 'seaborn-whitegrid']
In [90]:
# Change a style temporarily

# with plt.style.context('stylename'): 
    # make_a_plot()
In [91]:
def hist_and_lines():
    """Draw a two-panel demo figure: a histogram and three line plots.

    The global NumPy random generator is re-seeded on every call, so each
    call draws the same data and figures made under different stylesheets
    differ only in style. The figure is created with ``plt.subplots`` and
    nothing is returned.
    """
    np.random.seed(0)
    fig, ax = plt.subplots(1, 2, figsize=(11, 4))
    ax[0].hist(np.random.randn(1000))
    for i in range(3):
        ax[1].plot(np.random.rand(10))
    # Add the legend once all three lines exist, so that every label
    # has a line to attach to and the legend is built a single time.
    ax[1].legend(['a', 'b', 'c'], loc='lower left')

Default style

In [92]:
# reset rcParams 
plt.rcParams.update(IPython_default);

# Matplotlib’s default style
hist_and_lines()
/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/_collections_abc.py:841: MatplotlibDeprecationWarning: 
The examples.directory rcparam was deprecated in Matplotlib 3.0 and will be removed in 3.2. In the future, examples will be found relative to the 'datapath' directory.
  self[key] = other[key]
/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/_collections_abc.py:841: MatplotlibDeprecationWarning: 
The savefig.frameon rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.
  self[key] = other[key]
/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/_collections_abc.py:841: MatplotlibDeprecationWarning: 
The text.latex.unicode rcparam was deprecated in Matplotlib 3.0 and will be removed in 3.2.
  self[key] = other[key]
/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/_collections_abc.py:841: MatplotlibDeprecationWarning: 
The verbose.fileo rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.
  self[key] = other[key]
/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/_collections_abc.py:841: MatplotlibDeprecationWarning: 
The verbose.level rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.
  self[key] = other[key]

FiveThirtyEight style

In [93]:
with plt.style.context('fivethirtyeight'): 
    hist_and_lines()

ggplot

In [94]:
with plt.style.context('ggplot'):
    hist_and_lines()

Bayesian Methods for Hackers style

In [95]:
with plt.style.context('bmh'):
    hist_and_lines()

Dark background

In [96]:
with plt.style.context('dark_background'):
    hist_and_lines()

Grayscale

e.g., publication that does not accept color figures

In [97]:
with plt.style.context('grayscale'):
    hist_and_lines()

Seaborn style

In [98]:
with plt.style.context('seaborn'):
    hist_and_lines()
In [99]:
# Set back to the initial style sheet

plt.style.use('seaborn-whitegrid') 
%matplotlib inline 

Three-Dimensional Plotting in Matplotlib

In [100]:
# Turn the 3d plots interactive

#%matplotlib notebook
#%matplotlib notebook

#import matplotlib as mpl
#import matplotlib.pyplot as plt

#%matplotlib notebook
#%matplotlib notebook
In [101]:
# We enable three-dimensional plots by importing the mplot3d 
# toolkit, included with the main Matplotlib installation
from mpl_toolkits import mplot3d

# An empty three-dimensional axes
fig = plt.figure()
ax = plt.axes(projection='3d')

# recall that to use interactive figures, you can use %matplotlib notebook 
# rather than %matplotlib inline when running this code.

Three-Dimensional Points and Lines

The most basic three-dimensional plot is a line or scatter plot created from sets of (x, y, z) triples. We can create these using the ax.plot3D and ax.scatter3D functions.

In [102]:
# Points and lines in three dimensions
ax = plt.axes(projection='3d')

# Data for a three-dimensional line
zline = np.linspace(0, 15, 1000)
xline = np.sin(zline)
yline = np.cos(zline)
ax.plot3D(xline, yline, zline, 'gray')

# Data for three-dimensional scattered points
zdata = 15 * np.random.random(100)
xdata = np.sin(zdata) + 0.1 * np.random.randn(100)
ydata = np.cos(zdata) + 0.1 * np.random.randn(100)
ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap='Greens');

Three-Dimensional Contour Plots

In [103]:
def f(x, y):
    """Return sin(r), where r = sqrt(x**2 + y**2) is the distance from the origin.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    radius = np.sqrt(x ** 2 + y ** 2)
    return np.sin(radius)

x = np.linspace(-6, 6, 30)
y = np.linspace(-6, 6, 30)
X, Y = np.meshgrid(x, y)
print(type(x), type(X))
Z=f(X,Y)

# A three-dimensional contour plot
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.contour3D(X, Y, Z, 50, cmap='binary')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z');
<class 'numpy.ndarray'> <class 'numpy.ndarray'>
In [104]:
# Adjusting the view angle for a three-dimensional plot

# An elevation of 60 degrees (i.e., 60 degrees above the x-y plane)
# An azimuth of 35 degrees (i.e., rotated 35 degrees counter-clockwise about the z-axis)
ax.view_init(60, 35)
fig
Out[104]:

Wireframes and Surface Plots

In [105]:
# A wireframe plot

# These take a grid of values and project it onto the 
# specified three-dimensional surface.
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_wireframe(X, Y, Z, color='black')
ax.set_title('wireframe');
In [106]:
# A three-dimensional surface plot

# A surface plot is like a wireframe plot, but each face
# of the wireframe is a filled polygon.
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
    cmap='viridis', edgecolor='none')
ax.set_title('surface');
In [107]:
# A polar surface plot

# Example of creating a partial polar grid, which when used
# with the surface3D plot can give us a slice into the function
# we’re visualizing 
r = np.linspace(0, 6, 20)
theta = np.linspace(-0.9 * np.pi, 0.8 * np.pi, 40)
r, theta = np.meshgrid(r, theta)

X = r * np.sin(theta) 
Y = r * np.cos(theta)
Z=f(X,Y)

ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
    cmap='viridis', edgecolor='none');

Surface Triangulations

For some applications, the evenly sampled grids required by the preceding routines are overly restrictive and inconvenient. In these situations, the triangulation-based plots can be very useful.

In [108]:
# A three-dimensional sampled surface

theta = 2 * np.pi * np.random.random(1000)
r = 6 * np.random.random(1000)
x = np.ravel(r * np.sin(theta))
y = np.ravel(r * np.cos(theta))
z=f(x,y)

ax = plt.axes(projection='3d')
ax.scatter(x, y, z, c=z, cmap='viridis', linewidth=0.5);
In [109]:
# A triangulated surface plot

# ax.plot_trisurf: creates a surface by first finding a
# set of triangles formed between adjacent points
ax = plt.axes(projection='3d')
ax.plot_trisurf(x, y, z,
    cmap='viridis', edgecolor='none');

Example: Visualizing a Möbius strip

A Möbius strip is similar to a strip of paper glued into a loop with a half-twist. It has only a single side

In [110]:
theta = np.linspace(0, 2 * np.pi, 30)
w = np.linspace(-0.25, 0.25, 8)
w, theta = np.meshgrid(w, theta)

phi = 0.5 * theta

# radius in x-y plane 
r=1+w*np.cos(phi)

x = np.ravel(r * np.cos(theta))
y = np.ravel(r * np.sin(theta))
z = np.ravel(w * np.sin(phi))

# triangulate in the underlying parameterization 
from matplotlib.tri import Triangulation

tri = Triangulation(np.ravel(w), np.ravel(theta))
ax = plt.axes(projection='3d')
ax.plot_trisurf(x, y, z, triangles=tri.triangles,
    cmap='viridis', linewidths=0.2);
ax.set_xlim(-1, 1); ax.set_ylim(-1, 1); ax.set_zlim(-1, 1);

Geographic Data with Basemap

In [5]:
from mpl_toolkits.basemap import Basemap

# A “bluemarble” projection of the Earth
plt.figure(figsize=(8, 8))
m = Basemap(projection='ortho', resolution=None, lat_0=50, lon_0=-100)
m.bluemarble(scale=0.5);
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [6]:
# Plotting data and labels on the map
fig = plt.figure(figsize=(8, 8))
m = Basemap(projection='lcc', resolution=None,
    width=8E6, height=8E6,
    lat_0=45, lon_0=-100,)
m.etopo(scale=0.5, alpha=0.5)

# Map (long, lat) to (x, y) for plotting
x, y = m(-122.3, 47.6)
plt.plot(x, y, 'ok', markersize=5)
plt.text(x, y, ' Seattle', fontsize=12);
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Map Projections

In [10]:
from itertools import chain

def draw_map(m, scale=0.2):
    """Draw a shaded-relief background and a faint lat/lon grid on map ``m``.

    Parameters
    ----------
    m : map object providing ``shadedrelief``, ``drawparallels`` and
        ``drawmeridians`` (a Basemap instance in this notebook).
    scale : float, optional
        Passed through to ``m.shadedrelief``.
    """
    # shaded-relief background image
    m.shadedrelief(scale=scale)

    # 13 evenly spaced parallels and meridians; each call returns a dict
    # whose values carry the drawn line objects in their first element
    parallels = m.drawparallels(np.linspace(-90, 90, 13))
    meridians = m.drawmeridians(np.linspace(-180, 180, 13))

    # flatten to a single stream of lines: parallels first, then meridians
    grid_entries = chain(parallels.values(), meridians.values())
    grid_lines = chain.from_iterable(entry[0] for entry in grid_entries)

    # restyle every grid line as a solid, semi-transparent white line
    for line in grid_lines:
        line.set(linestyle='-', alpha=0.3, color='w')
In [11]:
# Cylindrical equal-area projection


# latitude (lat) and longitude (lon)
# lower-left corner (llcrnr) and upper-right corner (urcrnr)
# in units of degrees.
fig = plt.figure(figsize=(8, 6), edgecolor='w')
m = Basemap(projection='cyl', resolution=None,
    llcrnrlat=-90, urcrnrlat=90, 
    llcrnrlon=-180, urcrnrlon=180, )
draw_map(m)
In [12]:
# Pseudo-cylindrical projections
# The Mollweide projection

# The Mollweide projection (projection='moll') is one common
# example of this, in which all meridians are elliptical arcs
fig = plt.figure(figsize=(8, 6), edgecolor='w')
m = Basemap(projection='moll', resolution=None,
    lat_0=0, lon_0=0)

draw_map(m)
In [15]:
# Perspective projections
## The orthographic projection
fig = plt.figure(figsize=(8, 8))
m = Basemap(projection='ortho', resolution=None,
    lat_0=50, lon_0=0)
draw_map(m);
In [17]:
# Conic projections
## The Lambert conformal conic projection (projection='lcc');
## an Albers equal-area map would use projection='aea' instead
fig = plt.figure(figsize=(8, 8))
m = Basemap(projection='lcc', resolution=None,
    lon_0=0, lat_0=50, lat_1=45, lat_2=55,
    width=1.6E7, height=1.2E7)
draw_map(m)

Plotting Data on Maps

Example: California Cities

In [21]:
# Scatter plot over a map background

cities = pd.read_csv('data/california_cities.csv')

# Extract the data we're interested in
lat = cities['latd'].values
lon = cities['longd'].values
population = cities['population_total'].values
area = cities['area_total_km2'].values

# 1. Draw the map background 
fig = plt.figure(figsize=(8, 8))
m = Basemap(projection='lcc', resolution='h',
    lat_0=37.5, lon_0=-119,
    width=1E6, height=1.2E6)
m.shadedrelief()
m.drawcoastlines(color='gray')
m.drawcountries(color='gray')
m.drawstates(color='gray')

# 2. scatter city data, with color reflecting population 
# and size reflecting area
m.scatter(lon, lat, latlon=True,
    c=np.log10(population), s=area,
    cmap='Reds', alpha=0.5)

# 3. create colorbar and legend
plt.colorbar(label=r'$\log_{10}({\rm population})$')
plt.clim(3, 7)

# make legend with dummy points: one empty scatter per reference area,
# so that the legend shows marker sizes for 100, 300 and 500 km^2
for a in [100, 300, 500]:
    plt.scatter([], [], c='k', alpha=0.5, s=a,
        label=str(a) + ' km$^2$')
# build the legend once, after all three size markers have been registered
plt.legend(scatterpoints=1, frameon=False,
    labelspacing=1, loc='lower left');

Visualization with Seaborn

There are several valid complaints about Matplotlib that often come up:

  • Prior to version 2.0, Matplotlib’s defaults are not exactly the best choices. It was based off of MATLAB circa 1999, and this often shows.
  • Matplotlib’s API is relatively low level. Doing sophisticated statistical visualization is possible, but often requires a lot of boilerplate code.
  • Matplotlib predated Pandas by more than a decade, and thus is not designed for use with Pandas DataFrames. In order to visualize data from a Pandas DataFrame, you must extract each Series and often concatenate them together into the right format. It would be nicer to have a plotting library that can intelligently use the DataFrame labels in a plot.

An answer to these problems is Seaborn.

In [26]:
# Create some data
rng = np.random.RandomState(0)
x = np.linspace(0, 10, 500)
y = np.cumsum(rng.randn(500, 6), 0)
In [31]:
# Plot the data with Matplotlib defaults 
# Data in Matplotlib’s default style
plt.plot(x, y)
plt.legend('ABCDEF', ncol=2, loc='upper left');
In [33]:
import seaborn as sns
sns.set()

# Seaborn can also overwrite Matplotlib’s default parameters,
# which in turn gets even simple Matplotlib scripts to produce
# vastly superior output.

# same plotting code as above! 
plt.plot(x, y)
plt.legend('ABCDEF', ncol=2, loc='upper left');

Exploring Seaborn Plots

Histograms, KDE, and densities

Often in statistical data visualization, all you want is to plot histograms and joint distributions of variables.

In [39]:
# Histograms for visualizing distributions

data = np.random.multivariate_normal([0, 0], [[5, 2], [2, 2]], size=2000)
data = pd.DataFrame(data, columns=['x', 'y'])
for col in 'xy':
    plt.hist(data[col], density=True, alpha=0.5)
In [41]:
# Kernel density estimates for visualizing distributions

for col in 'xy': 
    sns.kdeplot(data[col], shade=True)
In [42]:
# Kernel density estimates (KDE) and histograms plotted together
sns.distplot(data['x'])
sns.distplot(data['y']);
In [48]:
# A two-dimensional kernel density plot

# If we pass two vectors to kdeplot, 
# we will get a two-dimensional visualization of the data
sns.kdeplot(data['x'], data['y']);
In [49]:
# A joint distribution plot with a two-dimensional kernel density estimate

# We can see the joint distribution and the marginal
# distributions together using sns.jointplot.
# x, y and data are passed by keyword: recent seaborn releases no longer
# accept them positionally, while older releases accept the keywords too.
with sns.axes_style('white'):
    sns.jointplot(x="x", y="y", data=data, kind='kde');
In [50]:
# A joint distribution plot with a hexagonal bin representation
# (x, y and data are given by keyword so the call works across seaborn versions)
with sns.axes_style('white'):
    sns.jointplot(x="x", y="y", data=data, kind='hex')

Pair Plot

In [54]:
# A pair plot showing the relationships between four variables
iris = sns.load_dataset("iris")    
sns.pairplot(iris, hue='species', height=2.5);

Faceted histograms

Sometimes the best way to view data is via histograms of subsets. Seaborn’s FacetGrid makes this extremely simple. We’ll take a look at some data that shows the amount that restaurant staff receive in tips based on various indicator data.

In [60]:
# An example of a faceted histogram

tips = sns.load_dataset('tips')
tips['tip_pct'] = 100 * tips['tip'] / tips['total_bill']
grid = sns.FacetGrid(tips, row="sex", col="time", margin_titles=True)
grid.map(plt.hist, "tip_pct", bins=np.linspace(0, 40, 15));

Factor plots

Factor plots can be useful for this kind of visualization as well. This allows you to view the distribution of a parameter within bins defined by any other parameter

In [71]:
# An example of a factor plot, comparing distributions given
# various discrete factors: total bill per day, split by sex.
# x, y and hue are given by keyword so the call works across seaborn versions.
with sns.axes_style(style='ticks'):
    g = sns.catplot(x="day", y="total_bill", hue="sex", data=tips, kind="box")
    g.set_axis_labels("Day", "Total Bill");

Joint distributions

In [64]:
# A joint distribution plot of tip against total bill
# (x and y are given by keyword so the call works across seaborn versions)
with sns.axes_style('white'):
    sns.jointplot(x="total_bill", y="tip", data=tips, kind='hex')
In [65]:
# The joint plot can even do some automatic kernel density estimation and regression
# (x and y are given by keyword so the call works across seaborn versions)
sns.jointplot(x="total_bill", y="tip", data=tips, kind='reg');

Bar plots

Time series can be plotted with sns.catplot (called sns.factorplot before seaborn 0.9).

In [69]:
planets = sns.load_dataset('planets')
planets.head(2)
Out[69]:
method number orbital_period mass distance year
0 Radial Velocity 1 269.300 7.10 77.40 2006
1 Radial Velocity 1 874.774 2.21 56.95 2008
In [73]:
# A histogram as a special case of a factor plot: kind="count" draws
# one bar per year with the number of rows for that year.
# The column is given as x= so the call works across seaborn versions.
with sns.axes_style('white'):
    g = sns.catplot(x="year", data=planets, aspect=2,
        kind="count", color='steelblue')
    # label only every fifth year to keep the x axis readable
    g.set_xticklabels(step=5)
In [80]:
#### Number of planets discovered by year and type
# Bars are grouped by discovery method and restricted to 2001-2014 via `order`.
# The column is given as x= so the call works across seaborn versions.
with sns.axes_style('white'):
    g = sns.catplot(x="year", data=planets, aspect=4.0, kind='count',
        hue='method', order=range(2001, 2015))
    g.set_ylabels('Number of Planets Discovered')

Example: Exploring Marathon Finishing Times

In [154]:
data = pd.read_csv('data/marathon-data.csv')
data.head()
Out[154]:
age gender split final
0 33 M 01:05:38 02:08:51
1 32 M 01:06:26 02:09:28
2 31 M 01:06:49 02:10:42
3 38 M 01:06:16 02:13:45
4 31 M 01:06:32 02:13:59
In [155]:
data.dtypes
Out[155]:
age        int64
gender    object
split     object
final     object
dtype: object
In [157]:
# pandas has no dtype for a time of day without a date, so parsing the
# 'HH:MM:SS' strings with to_datetime yields full timestamps: the date part
# is filled in as 1900-01-01 (see the output below). Only the hour, minute
# and second components are used afterwards.
data['split'] = pd.to_datetime(data['split'], format='%H:%M:%S')
data['final'] = pd.to_datetime(data['final'], format='%H:%M:%S')
data.head()
Out[157]:
age gender split final
0 33 M 1900-01-01 01:05:38 1900-01-01 02:08:51
1 32 M 1900-01-01 01:06:26 1900-01-01 02:09:28
2 31 M 1900-01-01 01:06:49 1900-01-01 02:10:42
3 38 M 1900-01-01 01:06:16 1900-01-01 02:13:45
4 31 M 1900-01-01 01:06:32 1900-01-01 02:13:59
In [158]:
data.dtypes
Out[158]:
age                int64
gender            object
split     datetime64[ns]
final     datetime64[ns]
dtype: object
In [159]:
def convert_to_seconds(time):
    """Convert a datetime Series to seconds elapsed since midnight.

    Only the hour, minute and second components are used; the date part
    of each timestamp is ignored.
    """
    clock = time.dt
    from_hours = clock.hour * 3600
    from_minutes = clock.minute * 60
    return from_hours + from_minutes + clock.second

data['split_sec'] = convert_to_seconds(data['split'])
data['final_sec'] = convert_to_seconds(data['final'])
data.head()
Out[159]:
age gender split final split_sec final_sec
0 33 M 1900-01-01 01:05:38 1900-01-01 02:08:51 3938 7731
1 32 M 1900-01-01 01:06:26 1900-01-01 02:09:28 3986 7768
2 31 M 1900-01-01 01:06:49 1900-01-01 02:10:42 4009 7842
3 38 M 1900-01-01 01:06:16 1900-01-01 02:13:45 3976 8025
4 31 M 1900-01-01 01:06:32 1900-01-01 02:13:59 3992 8039
In [160]:
# x, y and data are given by keyword so the call works across seaborn versions
with sns.axes_style('white'):
    g = sns.jointplot(x="split_sec", y="final_sec", data=data, kind='hex')
    # dotted reference line where the final time is exactly twice the split
    g.ax_joint.plot(np.linspace(4000, 16000),
        np.linspace(8000, 32000), ':k')
    
# The relationship between the split for the first half-marathon
# and the finishing time for the full marathon
In [161]:
# Let’s create another column in the data, the split fraction,
# which measures the degree to which each runner negative-splits
# or positive-splits the race:
data['split_frac'] = 1 - 2 * data['split_sec'] / data['final_sec']
data.head()
Out[161]:
age gender split final split_sec final_sec split_frac
0 33 M 1900-01-01 01:05:38 1900-01-01 02:08:51 3938 7731 -0.018756
1 32 M 1900-01-01 01:06:26 1900-01-01 02:09:28 3986 7768 -0.026262
2 31 M 1900-01-01 01:06:49 1900-01-01 02:10:42 4009 7842 -0.022443
3 38 M 1900-01-01 01:06:16 1900-01-01 02:13:45 3976 8025 0.009097
4 31 M 1900-01-01 01:06:32 1900-01-01 02:13:59 3992 8039 0.006842
In [163]:
# The distribution of split fractions; 0.0 indicates a runner
# who completed the first and second halves in identical times
sns.distplot(data['split_frac'], kde=False);
plt.axvline(0, color="k", linestyle="--");
In [164]:
# Out of nearly 40,000 participants, only about 250 people
# negative-split their marathon (split_frac < 0).
# Count with the vectorized Series.sum() rather than Python's builtin sum().
(data.split_frac < 0).sum()
Out[164]:
251
In [167]:
# The relationship between quantities within the marathon dataset
# Check if there are any correlations
g = sns.PairGrid(data, vars=['age', 'split_sec', 'final_sec', 'split_frac'],
    hue='gender', palette='RdBu_r')
g.map(plt.scatter, alpha=0.1)
g.add_legend();
In [168]:
# The distribution of split fractions by gender
sns.kdeplot(data.split_frac[data.gender=='M'], label='men', shade=True)
sns.kdeplot(data.split_frac[data.gender=='W'], label='women', shade=True)
plt.xlabel('split_frac');
In [169]:
# A violin plot showing the split fraction by gender
# (x and y are given by keyword so the call works across seaborn versions)
sns.violinplot(x="gender", y="split_frac", data=data,
    palette=["lightblue", "lightpink"]);
In [170]:
# A new column that specifies the decade of age each person is in
# (e.g. 33 -> 30), computed with vectorized integer arithmetic
data['age_dec'] = 10 * (data.age // 10)
data.head()
Out[170]:
age gender split final split_sec final_sec split_frac age_dec
0 33 M 1900-01-01 01:05:38 1900-01-01 02:08:51 3938 7731 -0.018756 30
1 32 M 1900-01-01 01:06:26 1900-01-01 02:09:28 3986 7768 -0.026262 30
2 31 M 1900-01-01 01:06:49 1900-01-01 02:10:42 4009 7842 -0.022443 30
3 38 M 1900-01-01 01:06:16 1900-01-01 02:13:45 3976 8025 0.009097 30
4 31 M 1900-01-01 01:06:32 1900-01-01 02:13:59 3992 8039 0.006842 30
In [171]:
men = (data.gender == 'M')
women = (data.gender == 'W')
In [176]:
# A violin plot showing the split fraction by gender and age
# split=True draws the two genders as the two halves of each violin.
# x and y are given by keyword so the call works across seaborn versions.
with sns.axes_style(style=None):
    sns.violinplot(x="age_dec", y="split_frac",
        hue="gender", data=data,
        split=True, inner="quartile",
        palette=["lightblue", "lightpink"]);
In [173]:
# Also surprisingly, the 80-year-old women seem to outperform
# everyone in terms of their split time. This is probably due
# to the fact that we’re estimating the distribution from small
# numbers, as there are only a handful of runners in that range:
(data.age > 80).sum()
Out[173]:
7
In [174]:
# Split fraction versus finishing time by gender
# lmplot automatically fits a linear regression to the data
# (x and y are given by keyword so the call works across seaborn versions)
g = sns.lmplot(x='final_sec', y='split_frac', col='gender', data=data,
    markers=".", scatter_kws=dict(color='c'))
# dotted horizontal reference line at split_frac = 0.1 in every panel
g.map(plt.axhline, y=0.1, color="k", ls=":");
In [ ]: